# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Dict, List, Optional, Union

from typing_extensions import override

from distilabel.steps import GlobalStep, StepInput
from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.typing import ChatType
from distilabel.steps.typing import StepOutput


class ArenaHard(Task):
    """Evaluates two assistant responses using an LLM as judge.

    This `Task` is based on the "From Live Data to High-Quality Benchmarks: The
    Arena-Hard Pipeline" paper that presents Arena Hard, which is a benchmark for
    instruction-tuned LLMs that contains 500 challenging user queries. GPT-4 is used
    as the judge to compare the model responses against a baseline model, which defaults
    to `gpt-4-0314`.

    Note:
        Arena-Hard-Auto has the highest correlation and separability to Chatbot Arena
        among popular open-ended LLM benchmarks.

    Input columns:
        - instruction (`str`): The instruction to evaluate the responses.
        - generations (`List[str]`): The responses generated by two, and only two, LLMs.

    Output columns:
        - evaluation (`str`): The evaluation of the responses generated by the LLMs.
        - score (`str`): The score extracted from the evaluation.
        - model_name (`str`): The model name used to generate the evaluation.

    Categories:
        - benchmark

    References:
        - [From Live Data to High-Quality Benchmarks: The Arena-Hard Pipeline](https://lmsys.org/blog/2024-04-19-arena-hard/)
        - [`arena-hard-auto`](https://github.com/lm-sys/arena-hard-auto/tree/main)

    Examples:

        Evaluate two assistant responses for a given instruction using Arena Hard prompts:

        ```python
        from distilabel.pipeline import Pipeline
        from distilabel.steps import GroupColumns, LoadDataFromDicts
        from distilabel.steps.tasks import ArenaHard, TextGeneration

        with Pipeline() as pipeline:
            load_data = LoadDataFromDicts(
                data=[{"instruction": "What is the capital of France?"}],
            )

            text_generation_a = TextGeneration(
                llm=...,  # LLM instance
                output_mappings={"model_name": "generation_model"},
            )

            text_generation_b = TextGeneration(
                llm=...,  # LLM instance
                output_mappings={"model_name": "generation_model"},
            )

            combine = GroupColumns(
                columns=["generation", "generation_model"],
                output_columns=["generations", "generation_models"],
            )

            arena_hard = ArenaHard(
                llm=...,  # LLM instance
            )

            load_data >> [text_generation_a, text_generation_b] >> combine >> arena_hard
        ```
    """

    @property
    def inputs(self) -> List[str]:
        """The inputs required by this task are the `instruction` and the `generations`,
        which are the responses generated by two, and only two, LLMs."""
        return ["instruction", "generations"]

    def format_input(self, input: Dict[str, Any]) -> ChatType:
        """This method formats the input data as a `ChatType` using the prompt defined
        by the Arena Hard benchmark, which consists of a `system_prompt` plus a template
        for the user first message that contains the `instruction` and both `generations`.

        Only the first two entries of `generations` are used, as assistant A and
        assistant B respectively.
        """
        return [
            {
                "role": "system",
                "content": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user prompt displayed below. You will be given assistant A's answer and assistant B's answer. Your job is to evaluate which assistant's answer is better.\n\nBegin your evaluation by generating your own answer to the prompt. You must provide your answers before judging any answers.\n\nWhen evaluating the assistants' answers, compare both assistants' answers with your answer. You must identify and correct any mistakes or inaccurate information.\n\nThen consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate to what is being asked. Concise means the response is clear and not verbose or excessive.\n\nThen consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing important information in the assistants' answers that would be beneficial to include when responding to the user prompt.\n\nAfter providing your explanation, you must output only one of the following choices as your final verdict with a label:\n\n1. Assistant A is significantly better: [[A>>B]]\n2. Assistant A is slightly better: [[A>B]]\n3. Tie, relatively the same: [[A=B]]\n4. Assistant B is slightly better: [[B>A]]\n5. Assistant B is significantly better: [[B>>A]]\n\nExample output: \"My final verdict is tie: [[A=B]]\".",
            },
            {
                "role": "user",
                "content": f"<|User Prompt|>\n{input['instruction']}\n\n<|The Start of Assistant A's Answer|>\n{input['generations'][0]}\n<|The End of Assistant A's Answer|>\n\n<|The Start of Assistant B's Answer|>\n{input['generations'][1]}\n<|The End of Assistant B's Answer|>",
            },
        ]

    @property
    def outputs(self) -> List[str]:
        """The outputs generated by this task are the `evaluation`, the `score` and
        the `model_name` (which is automatically injected within the `process` method
        of the parent task)."""
        return ["evaluation", "score", "model_name"]

    def format_output(
        self,
        output: Union[str, None],
        input: Union[Dict[str, Any], None] = None,
    ) -> Dict[str, Any]:
        """This method formats the output generated by the LLM as a Python dictionary
        containing the `evaluation`, which is the raw output generated by the LLM (consisting
        of the judge LLM alternate generation for the given instruction, plus an explanation
        on the evaluation of the given responses), plus the `score` extracted from the output.

        The `score` is the content of the `[[...]]` verdict label (e.g. `A>>B`). If the
        output contains no verdict, or contains several verdicts that disagree with each
        other, the judgment is considered unparseable and the `score` is `None`.

        Args:
            output: the raw output of the LLM.
            input: the input to the task. Is provided in case it needs to be used to enrich
                the output if needed.

        Returns:
            A dict with the keys `evaluation` with the raw output which contains the LLM
            evaluation and the extracted `score` if possible (`None` otherwise).
        """
        if output is None:
            return {"evaluation": None, "score": None}
        # Collect every verdict label rather than only the first one: a judge echoing the
        # example verdict before its own would otherwise yield the wrong score. As in the
        # arena-hard-auto judgment parsing, only an unambiguous verdict is accepted.
        verdicts = set(re.findall(r"\[\[([AB<>=]+)\]\]", output))
        score = verdicts.pop() if len(verdicts) == 1 else None
        return {"evaluation": output, "score": score}


class ArenaHardResults(GlobalStep):
    """Process Arena Hard results to calculate the ELO scores.

    This `Step` is based on the "From Live Data to High-Quality Benchmarks: The
    Arena-Hard Pipeline" paper that presents Arena Hard, which is a benchmark for
    instruction-tuned LLMs that contains 500 challenging user queries. This step is
    a `GlobalStep` that should run right after the `ArenaHard` task to calculate the
    ELO scores for the evaluated models.

    Note:
        Arena-Hard-Auto has the highest correlation and separability to Chatbot Arena
        among popular open-ended LLM benchmarks.

    Input columns:
        - evaluation (`str`): The evaluation of the responses generated by the LLMs.
        - score (`str`): The score extracted from the evaluation.
        - `custom_model_column` (`List[str]`, optional): If set, the column containing
            the names of the two models that generated the responses, in the same order
            as the `generations` given to `ArenaHard`.

    References:
        - [From Live Data to High-Quality Benchmarks: The Arena-Hard Pipeline](https://lmsys.org/blog/2024-04-19-arena-hard/)
        - [`arena-hard-auto`](https://github.com/lm-sys/arena-hard-auto/tree/main)

    Examples:

        Compute the ELO scores for two assistants given an evaluation / comparison between both using Arena Hard prompts:

        ```python
        from distilabel.pipeline import Pipeline
        from distilabel.steps import GroupColumns, LoadDataFromDicts
        from distilabel.steps.tasks import ArenaHard, ArenaHardResults, TextGeneration

        with Pipeline() as pipeline:
            load_data = LoadDataFromDicts(
                data=[{"instruction": "What is the capital of France?"}],
            )

            text_generation_a = TextGeneration(
                llm=...,  # LLM instance
                output_mappings={"model_name": "generation_model"},
            )

            text_generation_b = TextGeneration(
                llm=...,  # LLM instance
                output_mappings={"model_name": "generation_model"},
            )

            combine = GroupColumns(
                columns=["generation", "generation_model"],
                output_columns=["generations", "generation_models"],
            )

            arena_hard = ArenaHard(
                llm=...,  # LLM instance
            )

            arena_hard_results = ArenaHardResults(
                custom_model_column="generation_models",
                custom_weights={"A>B": 1, "A>>B": 3, "B>A": 1, "B>>A": 3},
            )

            load_data >> [text_generation_a, text_generation_b] >> combine >> arena_hard >> arena_hard_results
        ```

    """

    # Name of the input column holding the two model names (e.g. `generation_models`);
    # when unset the models are identified as `A` and `B`.
    custom_model_column: Optional[str] = None
    # Number of battle rows each non-tie verdict is expanded into (ties always count once).
    custom_weights: Dict[str, int] = {"A>B": 1, "A>>B": 3, "B>A": 1, "B>>A": 3}

    def load(self) -> None:
        """Ensures that the required dependencies are installed."""
        super().load()

        try:
            import numpy as np  # noqa: F401
            import pandas as pd  # noqa: F401
            from sklearn.linear_model import LogisticRegression  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "In order to run `ArenaHardResults`, the `arena-hard` extra dependencies"
                " must be installed i.e. `numpy`, `pandas`, and `scikit-learn`.\n"
                "Please install the dependencies by running `pip install distilabel[arena-hard]`."
            ) from e

    # TODO: the `evaluation` is not really required as an input, so it could be removed, since
    # only `score` is used / required
    @property
    def inputs(self) -> List[str]:
        """The inputs required by this step are the `evaluation` and the `score` generated
        by the `ArenaHard` task. By default the two models are identified as `A` and `B`;
        optionally one can set `custom_model_column` to the column containing the names of
        the two models whose responses were judged (e.g. `generation_models` when the
        generations were grouped with `GroupColumns`), in the same order as the
        `generations`. Note that the `model_name` output of `ArenaHard` is the judge model,
        not one of the evaluated models."""
        columns = ["evaluation", "score"]
        if self.custom_model_column:
            columns.append(self.custom_model_column)
        return columns

    @override
    def process(self, inputs: StepInput) -> StepOutput:  # type: ignore
        """This method processes the inputs generated by the `ArenaHard` task to calculate the
        ELO scores for each of the models to evaluate. Since this step inherits from the `GlobalStep`,
        it will wait for all the input batches to be processed, and then the output will be yielded in
        case there's a follow up step, since this step won't modify the received inputs.

        The ELO scores are only logged. If no input contains a valid `score`, or every
        battle has the same outcome (so the ELO scores cannot be estimated), a warning is
        logged instead and the inputs are yielded unchanged.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Yields:
            The received list of Python dictionaries, unmodified.

        References:
            - https://github.com/lm-sys/arena-hard-auto/blob/main/show_result.py
        """
        import numpy as np
        import pandas as pd
        from sklearn.linear_model import LogisticRegression

        # TODO: the battles are only calculated for the first game, even though the official
        # implementation also covers the possibility of a second game (not within the released
        # dataset yet)
        rows: List[Dict[str, Any]] = []
        for input in inputs:
            # The model names are read per row, so rows with a different model order are
            # still attributed correctly.
            pair = ["A", "B"]
            if self.custom_model_column:
                pair = input[self.custom_model_column]

            score = input["score"]
            # `winner` uses the `model_a` / `model_b` / `tie` labels, which is what the
            # target vector below is built from (model names would never match it).
            if score in ("A>B", "A>>B"):
                winner, weight = "model_a", self.custom_weights[score]
            elif score in ("B>A", "B>>A"):
                winner, weight = "model_b", self.custom_weights[score]
            elif score == "A=B":
                winner, weight = "tie", 1
            else:
                # Missing or unparseable verdicts are skipped
                continue

            battle = {
                # TODO: "question_id": input["question_id"],
                "model_a": pair[0],
                "model_b": pair[1],
                "winner": winner,
            }
            rows.extend([battle] * weight)

        if not rows:
            self._logger.warning(
                "Arena Hard ELO scores could not be computed since none of the inputs"
                " contains a valid `score`."
            )
            yield inputs
            return

        battles = pd.DataFrame(rows)

        names = pd.concat([battles["model_a"], battles["model_b"]]).unique()
        model_ids = pd.Series(np.arange(len(names)), index=names)

        # Duplicate the battles so that each tie can count as one win for each side
        battles = pd.concat([battles, battles], ignore_index=True)
        p = len(model_ids.index)
        n = battles.shape[0]

        X = np.zeros([n, p])
        X[np.arange(n), model_ids.loc[battles["model_a"]].to_numpy()] = +np.log(10)
        X[np.arange(n), model_ids.loc[battles["model_b"]].to_numpy()] = -np.log(10)

        Y = np.zeros(n)
        Y[(battles["winner"] == "model_a").to_numpy()] = 1.0

        # One tie => one `model_a` win (first half) + one `model_b` win (second half).
        # `copy=True` guarantees a writable mask even under pandas copy-on-write.
        tie_idx = (battles["winner"] == "tie").to_numpy(copy=True)
        tie_idx[n // 2 :] = False
        Y[tie_idx] = 1.0

        # The logistic regression needs both outcomes to be present (at least one tie,
        # or at least one win for each side); otherwise it cannot be fitted.
        if np.unique(Y).size < 2:
            self._logger.warning(
                "Arena Hard ELO scores could not be computed since all the battles have"
                " the same outcome; at least one tie or one win for each side is required."
            )
            yield inputs
            return

        lr = LogisticRegression(fit_intercept=False, penalty=None, tol=1e-8)  # type: ignore
        lr.fit(X, Y)

        # The ELO scores are calculated assuming that the reference is `gpt-4-0314`
        # with a starting ELO of 1000, so that the evaluated models are compared with
        # `gpt-4-0314` only if it's available within the models
        elo_scores = 400 * lr.coef_[0] + 1000
        # TODO: we could parametrize the reference / anchor model, but left as is to be faithful to the
        # original implementation
        if "gpt-4-0314" in model_ids.index:
            elo_scores += 1000 - elo_scores[model_ids.loc["gpt-4-0314"]]

        output = pd.Series(elo_scores, index=model_ids.index).sort_values(ascending=False)
        self._logger.info(f"Arena Hard ELO: {output}")

        # Here only so that if follow up steps are connected the inputs are preserved,
        # since this step doesn't modify nor generate new inputs
        yield inputs


if __name__ == "__main__":
    import json

    from distilabel.models import InferenceEndpointsLLM, OpenAILLM
    from distilabel.pipeline import Pipeline
    from distilabel.steps import (
        GroupColumns,
        KeepColumns,
        LoadDataFromHub,
        StepInput,
        step,
    )
    from distilabel.steps.tasks import TextGeneration
    from distilabel.steps.typing import StepOutput

    @step(inputs=["turns"], outputs=["system_prompt", "instruction"])
    def PrepareForTextGeneration(*inputs: StepInput) -> StepOutput:
        """Sets a fixed `system_prompt` and uses the first turn's content as `instruction`."""
        for input in inputs:
            for item in input:
                item["system_prompt"] = "You are a helpful assistant."
                item["instruction"] = item["turns"][0]["content"]
            yield input

    @step(
        inputs=["question_id"],
        outputs=["generation", "generation_model"],
        step_type="global",
    )
    def LoadReference(*inputs: StepInput) -> StepOutput:
        """Attaches the `gpt-4-0314` reference answer to each row, matched by `question_id`."""
        # File downloaded from https://raw.githubusercontent.com/lm-sys/arena-hard-auto/e0a8ea1df42c1df76451a6cd04b14e31ff992b87/data/arena-hard-v0.1/model_answer/gpt-4-0314.jsonl
        # Parse the file once into a `question_id` -> answer lookup (instead of re-parsing
        # every line for every row), closing the file as soon as it has been read.
        references: Dict[Any, Dict[str, Any]] = {}
        with open("gpt-4-0314.jsonl", mode="r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                # Keep the first answer per `question_id`, as a first-match scan would
                references.setdefault(data["question_id"], data)

        for input in inputs:
            for item in input:
                data = references.get(item["question_id"])
                if data is not None:
                    item["generation"] = data["choices"][0]["turns"][0]["content"]
                    item["generation_model"] = data["model_id"]
            yield input

    with Pipeline(name="arena-hard-v0.1") as pipeline:
        load_dataset = LoadDataFromHub(
            name="load_dataset",
            repo_id="alvarobartt/lmsys-arena-hard-v0.1",
            split="test",
            num_examples=5,
        )

        load_reference = LoadReference(name="load_reference")

        prepare = PrepareForTextGeneration(name="prepare")

        text_generation_cohere = TextGeneration(
            name="text_generation_cohere",
            llm=InferenceEndpointsLLM(
                model_id="CohereForAI/c4ai-command-r-plus",
                tokenizer_id="CohereForAI/c4ai-command-r-plus",
            ),
            use_system_prompt=True,
            input_batch_size=10,
            output_mappings={"model_name": "generation_model"},
        )

        combine_columns = GroupColumns(
            name="combine_columns",
            columns=["generation", "generation_model"],
            output_columns=["generations", "generation_models"],
        )

        arena_hard = ArenaHard(
            name="arena_hard",
            llm=OpenAILLM(model="gpt-4-1106-preview"),
            output_mappings={"model_name": "evaluation_model"},
        )

        keep_columns = KeepColumns(
            name="keep_columns",
            columns=[
                "question_id",
                "category",
                "cluster",
                "system_prompt",
                "instruction",
                "generations",
                "generation_models",
                "evaluation",
                "score",
                "evaluation_model",
            ],
        )

        win_rates = ArenaHardResults(
            name="win_rates", custom_model_column="generation_models"
        )

        load_dataset >> load_reference  # type: ignore
        load_dataset >> prepare >> text_generation_cohere  # type: ignore
        (  # type: ignore
            [load_reference, text_generation_cohere]
            >> combine_columns
            >> arena_hard
            >> keep_columns
            >> win_rates
        )

        distiset = pipeline.run(
            parameters={  # type: ignore
                text_generation_cohere.name: {
                    "llm": {
                        "generation_kwargs": {
                            "temperature": 0.7,
                            "max_new_tokens": 4096,
                            "stop_sequences": ["<EOS_TOKEN>", "<|END_OF_TURN_TOKEN|>"],
                        }
                    }
                },
                arena_hard.name: {
                    "llm": {
                        "generation_kwargs": {
                            "temperature": 0.0,
                            "max_new_tokens": 4096,
                        }
                    }
                },
            },
        )
        if distiset is not None:
            distiset.push_to_hub("arena-hard-results")
